Source code for lmcsc.transformation_type

from collections import defaultdict, Counter
from math import ceil
import os
import pickle

from pypinyin import pinyin, Style
from pypinyin_dict.pinyin_data import ktghz2013
from pypinyin_dict.phrase_pinyin_data import large_pinyin
import json
from tqdm import tqdm

from lmcsc.common import PUNCT, OOV_CHAR, consonant_inits, reAlNUM


# load better pinyin data
ktghz2013.load()
# load better phrase pinyin data
large_pinyin.load()


import yaml

class TransformationType:
    r"""
    A class for handling various types of transformations on input sequences, particularly for Chinese text.

    This class provides functionality to identify and categorize different types of character transformations,
    such as similar shapes, similar pronunciations, and common confusions in Chinese characters.

    Args:
        vocab (`dict`):
            A dictionary mapping tokens to their indices in the vocabulary.
        is_bytes_level (`bool`):
            Flag indicating whether the input is at the byte level.
        distortion_type_prior_priority (`list`, *optional*):
            A list specifying the priority order of distortion types. If not provided, a default order is used.
        config_path (`str`, *optional*, defaults to `'configs/default_config.yaml'`):
            Path to the configuration file containing paths to various dictionaries and resources.

    Attributes:
        similar_shape_dict (`dict`): Dictionary of characters with similar shapes.
        shape_confusion_dict (`dict`): Dictionary of characters prone to shape-based confusion.
        similar_consonant_dict (`dict`): Dictionary of similar consonants in pinyin.
        similar_vowel_dict (`dict`): Dictionary of similar vowels in pinyin.
        similar_spell_dict (`dict`): Dictionary of characters with similar spellings.
        near_spell_dict (`dict`): Dictionary of characters with near spellings.
        prone_to_confusion_dict (`dict`): Dictionary of characters prone to confusion.
        vocab (`dict`): The input vocabulary.
        is_bytes_level (`bool`): Flag indicating byte-level processing.
        distortion_type_priority_order (`list`): Ordered list of distortion type priorities.
        distortion_type_priority (`dict`): Dictionary mapping distortion types to their priorities.

    Note:
        This class relies on various external resources and dictionaries for Chinese language processing,
        which should be properly configured in the specified config file.
    """

    def __init__(self, vocab, is_bytes_level, distortion_type_prior_priority=None, config_path='configs/default_config.yaml'):
        r"""
        Initializes the TransformationType class.

        Args:
            vocab (dict): A dictionary mapping tokens to their indices in the vocabulary.
            is_bytes_level (bool): Flag indicating whether the input is at the byte level.
            distortion_type_prior_priority (list, optional): A list specifying the priority order of distortion
                types. If not provided, a default order is used.
            config_path (str, optional): Path to the configuration file containing paths to various dictionaries
                and resources. Defaults to 'configs/default_config.yaml'.
        """
        # Load configuration from YAML file
        with open(config_path, 'r') as config_file:
            config = yaml.safe_load(config_file)
        file_paths = config['transformation_type_paths']
        self.config = config

        # Load dictionary of characters with similar shapes
        self.similar_shape_dict = self.load_dict(file_paths['similar_shape_dict'])

        # Load set of characters that should never be treated as missing
        if 'length_immutable_chars' in file_paths:
            self.length_immutable_chars = set(self.load_list(file_paths['length_immutable_chars']))
        else:
            self.length_immutable_chars = set()

        # Load dictionary of characters with similar strokes
        self.shape_confusion_dict = self.load_dict(file_paths['shape_confusion_dict'])

        # Load dictionaries of similar consonants and vowels
        self.similar_consonant_dict = self.load_dict(file_paths['similar_consonant_dict'])
        self.similar_vowel_dict = self.load_dict(file_paths['similar_vowel_dict'])

        # Load dictionaries of similar and near spellings
        self.similar_spell_dict, self.near_spell_dict = self.load_similar_spell_dict(
            file_paths['pinyin_distance_matrix']
        )

        # Load dictionary of characters prone to confusion
        self.prone_to_confusion_dict = self.load_dict(file_paths['prone_to_confusion_dict'])

        # Store vocabulary, byte-level flag, and punctuation-insertion switch
        self.vocab = vocab
        self.is_bytes_level = is_bytes_level
        self.allow_insert_punct = os.getenv("ALLOW_INSERT_PUNCT", "false").lower() == "true"
        print(f"allow_insert_punct: {self.allow_insert_punct}")

        # Build distortion type priority
        self.build_distortion_type_priority(distortion_type_prior_priority)

        # Build inverse index for efficient lookup
        self.build_inverse_index()
    def build_distortion_type_priority(self, distortion_type_prior_priority):
        r"""
        Builds the distortion type priority.

        Args:
            distortion_type_prior_priority (list, optional): A list specifying the priority order of distortion
                types. If not provided, a default order is used.
        """
        default_distortion_type_prior_priority_order = self.config['distortion_type_prior_priority_order']
        if distortion_type_prior_priority is None:
            self.distortion_type_priority_order = default_distortion_type_prior_priority_order
            distortion_type_prior_priority = self.config['distortion_type_prior_priority']
        else:
            self.distortion_type_priority_order = [
                distortion_type
                for distortion_type in distortion_type_prior_priority
                if distortion_type in default_distortion_type_prior_priority_order
            ]
        self.distortion_type_priority = {
            distortion_type: len(self.distortion_type_priority_order) - i
            for i, distortion_type in enumerate(distortion_type_prior_priority)
        }
    def load_dict(self, file_name):
        r"""
        Loads and returns a JSON dictionary from one or more files.

        Args:
            file_name (str or list): The name of the file (or list of files) to load the dictionary from.

        Returns:
            dict: The loaded dictionary.
        """
        final_dict = {}
        if isinstance(file_name, str):
            file_name = [file_name]
        for file in file_name:
            with open(file, "r", encoding="utf-8") as f:
                final_dict.update(json.load(f))
        return final_dict
    def load_list(self, file_name):
        r"""
        Loads and returns a list from one or more JSON files.
        """
        final_list = []
        if isinstance(file_name, str):
            file_name = [file_name]
        for file in file_name:
            with open(file, "r", encoding="utf-8") as f:
                final_list.extend(json.load(f))
        return final_list
    def load_similar_spell_dict(self, file_name):
        r"""
        Loads the spell distance matrix from a pickle file.

        Args:
            file_name (str): The name of the file to load the spell distance matrix from.

        Returns:
            tuple: A tuple containing two dictionaries: similar_spell_dict and near_spell_dict.
        """
        self.spell_distance_matrix = pickle.load(open(file_name, "rb"))
        similar_spell_dict = defaultdict(set)
        near_spell_dict = defaultdict(set)
        # Populate similar_spell_dict based on the distance threshold;
        # near_spell_dict is currently returned empty.
        for pair, distance in self.spell_distance_matrix.items():
            if distance <= 1.0:
                similar_spell_dict[pair[0]].add(pair[1])
                similar_spell_dict[pair[1]].add(pair[0])
        return similar_spell_dict, near_spell_dict
    def bag_of_chars_hash(self, token):
        # Create a hash of character counts in a token
        counter = Counter(token)
        item_seq = sorted(counter.items(), key=lambda x: x[0])
        # Convert to a string representation
        item_seq = "-".join([f"{k}:{v}" for k, v in item_seq])
        return item_seq
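    # Illustrative example (not part of the original module): the hash is
    # order-insensitive, so transposed strings collide on the same key, e.g.
    #   bag_of_chars_hash("ab") == bag_of_chars_hash("ba") == "a:1-b:1"
    # which is what the reorder ("ROR") lookup in build_inverse_index relies on.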
    def init_pinyin_of_token_hash(self, token_consonants):
        # Create a hash of initial pinyin of a token
        new_token_consonants = []
        for c in token_consonants:
            if len(c) > 0 and len(c[0]) > 0:
                if c[0][0] in consonant_inits:
                    new_token_consonants.append(c[0][0])
                else:
                    new_token_consonants.append('1')
            else:
                new_token_consonants.append('0')
        return "-".join(new_token_consonants)
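    # Illustrative example (not part of the original module), assuming that
    # consonant_inits from lmcsc.common contains the standard pinyin initials
    # such as 'j' and 'q':
    #   pinyin("机器", style=Style.NORMAL, heteronym=False)  ->  [['ji'], ['qi']]
    #   init_pinyin_of_token_hash([['ji'], ['qi']])          ->  "j-q"
    # so an abbreviation such as "jq" can be matched to tokens like "机器" via init_pinyin_index.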
    def build_inverse_index(self):
        r"""
        Builds inverse indices for efficient lookup.

        This method constructs multiple index dictionaries that map certain features of tokens
        (such as pinyin, character positions, etc.) to the indices of tokens in the vocabulary.
        These indices are used to efficiently perform lookups based on various transformation types,
        such as identical characters, similar pinyin, similar shapes, etc.

        It processes each token in the vocabulary and builds indices for:

        - **Identical characters at specific positions** (`identical_char_index`)
        - **Characters prone to confusion at specific positions** (`prone_to_confusion_char_index`)
        - **Tokens sharing the same pinyin at specific positions** (`same_pinyin_index`)
        - **Tokens with similar pinyin at specific positions** (`similar_pinyin_index`)
        - **Tokens whose pinyin is similar due to spelling errors at specific positions** (`other_similar_pinyin_index`)
        - **Tokens with characters of similar shapes at specific positions** (`similar_shape_index`)
        - **Tokens with characters whose shapes are easily confused at specific positions** (`other_similar_shape_index`)
        - **Identical tokens** (`identical_token_index`)

        Additionally, it keeps track of:

        - **Token lengths** (`token_length`)
        - **Mapping of indices back to tokens** (`idx_to_token`)
        - **Set of all unique characters in tokens** (`char_set`)

        These indices facilitate quick retrieval of tokens based on various linguistic and orthographic features.
        """
        # Initialize various index dictionaries for different types of transformations
        self.identical_char_index = defaultdict(set)           # Maps (position, character) -> set of token indices
        self.prone_to_confusion_char_index = defaultdict(set)  # Maps (position, confusing character) -> set of token indices
        self.same_pinyin_index = defaultdict(set)              # Maps (position, pinyin) -> set of token indices
        self.similar_pinyin_index = defaultdict(set)           # Maps (position, similar pinyin) -> set of token indices
        self.init_pinyin_index = defaultdict(set)
        self.other_similar_pinyin_index = defaultdict(set)     # Maps (position, other similar pinyin) -> set of token indices
        self.similar_shape_index = defaultdict(set)            # Maps (position, character with similar shape) -> set of token indices
        self.other_similar_shape_index = defaultdict(set)      # Maps (position, character with shape confusion) -> set of token indices
        self.reorder_index = defaultdict(set)
        self.missing_char_index = defaultdict(set)
        self.identical_token_index = defaultdict(set)          # Maps token -> set of indices
        self.idx_to_token = {}                                 # Maps token index -> token
        self.token_length = {}                                 # Maps token index -> token length
        self.char_set = set()                                  # Set of all unique characters in tokens
        self.is_chinese_token = {}                             # Maps token index -> is_chinese_token

        # Iterate through all vocabulary items
        for k, idx in tqdm(self.vocab.items()):
            ori_token = k  # Original token (could be bytes or string)
            # If tokens are at byte level, attempt to decode them
            if self.is_bytes_level:
                try:
                    # Try to decode byte-level token to UTF-8 string
                    ori_token = k.decode("utf-8")
                except UnicodeDecodeError:
                    # If decoding fails, handle each byte separately
                    for i, byte in enumerate(ori_token):
                        # Map the byte at position i to the token index
                        self.identical_char_index[(i, byte)].add(idx)
                    # Approximate the length of the token (assuming Chinese characters are 3 bytes)
                    self.token_length[idx] = len(ori_token) / 3
                    continue  # Skip to the next token

            # Handle special tokens that represent bytes (e.g., '<0x00>')
            if (
                len(ori_token) == 6
                and ori_token.startswith("<0x")
                and ori_token.endswith(">")
            ):
                # Assign a fractional length to represent a single byte character
                self.token_length[idx] = 1 / 3
            else:
                # For regular tokens, use the length of the token string
                self.token_length[idx] = len(ori_token)

            token = ori_token  # Use the potentially decoded token

            # Build identical token indices
            self.identical_token_index[token].add(idx)
            self.idx_to_token[idx] = token  # Map index back to token

            # Build reorder index
            self.reorder_index[self.bag_of_chars_hash(token)].add(idx)

            # Build character-level indices
            for i, char in enumerate(token):
                self.char_set.add(char)  # Keep track of unique characters
                # Map (position, character) to token index
                self.identical_char_index[(i, char)].add(idx)
                # Handle characters that are frequently confused with others
                for equal_char in self.prone_to_confusion_dict.get(char, []):
                    # Map (position, confusing character) to token index
                    self.prone_to_confusion_char_index[(i, equal_char)].add(idx)
                # Build missing character index
                if char not in self.length_immutable_chars and char not in PUNCT:
                    key = token[:i] + token[i + 1:]
                    self.missing_char_index[key].add(idx)

            # Build pinyin-related indices
            # Get list of possible pinyins for each character in the token
            token_pinyins = pinyin(token, style=Style.NORMAL, heteronym=True)
            # If the token is not converted (e.g., not Chinese characters), skip pinyin indices
            if len(token_pinyins) == 1 and token_pinyins[0][0] == token:
                # TODO: this skip may cause the short circuit of the shape confusion
                if self.allow_insert_punct and (token in PUNCT and len(token) == 1):
                    self.is_chinese_token[idx] = True
                else:
                    self.is_chinese_token[idx] = False
                continue  # Proceed to the next token
            else:
                self.is_chinese_token[idx] = True

            # For each character position and its possible pinyins
            for i, ps in enumerate(token_pinyins):
                for p in ps:
                    # Map (position, pinyin) to token index (exact pinyin match)
                    self.same_pinyin_index[(i, p)].add(idx)
                    # For pinyins that are similar due to spelling errors
                    for similar_pinyin in self.similar_spell_dict.get(p, []):
                        # Map (position, similar pinyin) to token index
                        self.other_similar_pinyin_index[(i, similar_pinyin)].add(idx)

            # Build similar pinyin indices based on consonants and vowels
            token_consonants = pinyin(token, style=Style.INITIALS, heteronym=True)  # Possible consonants
            token_vowels = pinyin(token, style=Style.FINALS, heteronym=True)        # Possible vowels
            for i, (consonant_variants, vowel_variants) in enumerate(zip(token_consonants, token_vowels)):
                # Initialize sets for similar consonants and vowels
                similar_consonants = set(consonant_variants)
                similar_vowels = set(vowel_variants)
                # Expand similar consonants based on predefined mappings
                for c in consonant_variants:
                    similar_consonants.update(self.similar_consonant_dict.get(c, [c]))
                # Expand similar vowels based on predefined mappings
                for v in vowel_variants:
                    similar_vowels.update(self.similar_vowel_dict.get(v, [v]))
                # Combine similar consonants and vowels to generate fuzzy pinyins
                for c in similar_consonants:
                    for v in similar_vowels:
                        # Concatenate consonant and vowel to form pinyin
                        fuzzy_pinyin = c + v
                        # Map (position, fuzzy pinyin) to token index
                        self.similar_pinyin_index[(i, fuzzy_pinyin)].add(idx)

            # Build initial pinyin index
            if len(token) > 1:
                token_consonants = pinyin(token, style=Style.NORMAL, heteronym=False)
                key = self.init_pinyin_of_token_hash(token_consonants)
                self.init_pinyin_index[key].add(idx)

            # Build shape-related indices
            for i, char in enumerate(token):
                # If the character has similar-shaped characters
                if char in self.similar_shape_dict:
                    for similar_char in self.similar_shape_dict[char]:
                        # Map (position, similar-shaped character) to token index
                        self.similar_shape_index[(i, similar_char)].add(idx)
                # If the character appears in the shape confusion dict
                if char in self.shape_confusion_dict:
                    for similar_char in self.shape_confusion_dict[char]:
                        # Map (position, shape-confused character) to token index
                        self.other_similar_shape_index[(i, similar_char)].add(idx)
    def handle_oov_characters(self, observed_sequence):
        r"""
        Handles out-of-vocabulary (OOV) characters.

        Args:
            observed_sequence (str or bytes): The observed sequence containing OOV characters.

        Returns:
            dict: A dictionary mapping token indices to their corresponding transformation types.
        """
        oov_transformation = {}
        if isinstance(observed_sequence[0], int):
            assert isinstance(observed_sequence, bytes)
            token_bytes = observed_sequence
        else:
            token_bytes = observed_sequence[0].encode("utf-8")
        idx = None
        for i in range(len(token_bytes)):
            tmp_byte = token_bytes[: i + 1]
            if tmp_byte in self.vocab:
                idx = self.vocab[tmp_byte]
        assert idx is not None, f"{observed_sequence} {token_bytes}"
        oov_transformation[idx] = ("IDT", )
        return oov_transformation
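    # Illustrative note (not part of the original module): for a byte-level
    # vocabulary, an observed character that never appears in any decoded token
    # is reduced to the longest prefix of its UTF-8 byte sequence that exists in
    # the vocabulary, and that single prefix token is labeled ("IDT",).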
    def get_pinyin_data(self, observed_sequence):
        r"""
        Gets pinyin data for the observed sequence.

        Args:
            observed_sequence (str): The observed sequence.

        Returns:
            tuple: A tuple containing two elements:
                - list: A list of pinyin representations for the observed sequence.
                - list: A list of consonant representations for the observed sequence.
        """
        try:
            token_pinyins = pinyin(
                observed_sequence, style=Style.NORMAL, heteronym=True,
                errors=lambda x: [char for char in x]
            )
            token_consonants = pinyin(
                observed_sequence, style=Style.NORMAL, heteronym=False,
                errors=lambda x: [char for char in x]
            )
        except:
            return None, None
        return token_pinyins, token_consonants
    def handle_identical_characters(self, i, char, token_transformation):
        r"""
        Handles identical characters.

        Args:
            i (int): The index of the character in the observed sequence.
            char (str): The character in the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for idx in self.identical_char_index.get((i, char), []):
            token_transformation[idx].setdefault(i, "IDT")
    def handle_prone_to_confusion(self, i, char, token_transformation):
        r"""
        Handles characters prone to confusion.

        Args:
            i (int): The index of the character in the observed sequence.
            char (str): The character in the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for idx in self.prone_to_confusion_char_index.get((i, char), []):
            token_transformation[idx].setdefault(i, "PTC")
    def is_punctuation_or_space(self, char):
        r"""
        Checks if a character is a punctuation mark or space.

        Args:
            char (str): The character to check.

        Returns:
            bool: True if the character is a punctuation mark or space, False otherwise.
        """
        # return char in PUNCT and char != "_"
        return char in PUNCT
    def handle_continuous_punctuation_or_space(self, i, observed_sequence, token_transformation):
        r"""
        Handles continuous punctuation or space.
        """
        if not isinstance(observed_sequence, str):
            return
        if i == 0:
            for l in range(2, len(observed_sequence) + 1):
                key = observed_sequence[:l]
                for idx in self.identical_token_index.get(key, []):
                    token_transformation[idx] = {k: "IDT" for k in range(len(key))}
    def handle_redundant_before_punctuation_or_space(self, i, char, observed_sequence, token_transformation, original_token_length):
        r"""
        Handles redundant characters before punctuation or space.
        """
        if i >= 1 and char in PUNCT:
            removed_char = observed_sequence[:i]
            for l in range(1, len(observed_sequence) - i):
                key = observed_sequence[i:i + l]
                for idx in self.identical_token_index.get(key, []):
                    token_transformation[idx] = {k: "RED" for k in range(len(removed_char))}
                    original_token_length[idx] = len(key.encode("utf-8")) if self.is_bytes_level else len(key)
    def handle_same_pinyin(self, i, token_pinyins, token_transformation):
        r"""
        Handles characters with the same pinyin.

        Args:
            i (int): The index of the character in the observed sequence.
            token_pinyins (list): A list of pinyin representations for the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for p in token_pinyins[i]:
            for idx in self.same_pinyin_index.get((i, p), []):
                token_transformation[idx].setdefault(i, "SAP")
    def handle_reorder_tokens(self, i, part_observed_sequence, token_transformation):
        r"""
        Handles reordered tokens.

        Args:
            i (int): The index of the character in the observed sequence.
            part_observed_sequence (str): The observed sequence up to and including the character at index i.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        if i > 0:
            current_hash = self.bag_of_chars_hash(part_observed_sequence)
            for idx in self.reorder_index.get(current_hash, []):
                if (idx not in token_transformation
                        or len(token_transformation[idx]) < self.token_length[idx]
                        or not set(token_transformation[idx].values()).issubset({"IDT", })):
                    token_transformation[idx] = {k: "ROR" for k in range(int(self.token_length[idx]))}
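    # Illustrative example (not part of the original module): with the observed
    # prefix "好你" (i == 1), bag_of_chars_hash("好你") equals bag_of_chars_hash("你好"),
    # so a vocabulary token "你好" is retrieved from reorder_index and labeled
    # {0: "ROR", 1: "ROR"}, flagging a likely character transposition.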
    def handle_initial_pinyin_match(self, i, part_consonants, token_transformation):
        r"""
        Handles initial pinyin match. For example, "jq" -> "机器", "精确", ...

        Args:
            i (int): The index of the character in the observed sequence.
            part_consonants (list): A list of consonant representations for the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        if i > 0:
            current_hash = self.init_pinyin_of_token_hash(part_consonants)
            imp_type_priority = self.distortion_type_priority["IMP"]
            for idx in self.init_pinyin_index.get(current_hash, []):
                if idx not in token_transformation:
                    token_transformation[idx] = {k: "IMP" for k in range(int(self.token_length[idx]))}
                elif len(token_transformation[idx]) < self.token_length[idx]:
                    token_length = int(self.token_length[idx])
                    token_trans = token_transformation.get(idx, {})
                    for range_idx in range(token_length):
                        if range_idx not in token_trans or self.distortion_type_priority[token_trans[range_idx]] < imp_type_priority:
                            token_trans[range_idx] = "IMP"
                    token_transformation[idx] = token_trans
    def handle_similar_pinyin(self, i, token_pinyins, token_transformation):
        r"""
        Handles characters with similar pinyin.

        Args:
            i (int): The index of the character in the observed sequence.
            token_pinyins (list): A list of pinyin representations for the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for p in token_pinyins[i]:
            for idx in self.similar_pinyin_index.get((i, p), []):
                token_transformation[idx].setdefault(i, "SIP")
    def handle_similar_shape(self, i, char, token_transformation):
        r"""
        Handles characters with similar shapes.

        Args:
            i (int): The index of the character in the observed sequence.
            char (str): The character in the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for idx in self.similar_shape_index.get((i, char), []):
            token_transformation[idx].setdefault(i, "SIS")
    def handle_other_pinyin_error(self, i, token_pinyins, token_transformation):
        r"""
        Handles other pinyin errors.

        Args:
            i (int): The index of the character in the observed sequence.
            token_pinyins (list): A list of pinyin representations for the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for p in token_pinyins[i]:
            for idx in self.other_similar_pinyin_index.get((i, p), []):
                token_transformation[idx].setdefault(i, "OTP")
    def handle_redundant_character_inside_token(self, i, part_observed_sequence, token_transformation, original_token_length):
        r"""
        Handles redundant characters inside the token.

        Args:
            i (int): The index of the character in the observed sequence.
            part_observed_sequence (str): The observed sequence up to and including the character at index i.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
            original_token_length (dict): A dictionary mapping token indices to their original lengths.
        """
        if i > 0:
            for j in range(1, i):
                key = part_observed_sequence[:j] + part_observed_sequence[j + 1:]
                removed_char = part_observed_sequence[j]
                if removed_char in PUNCT or reAlNUM.match(removed_char):
                    # do not remove punctuation, digits, or English letters
                    continue
                if j == 0 and removed_char in self.length_immutable_chars:
                    continue
                for idx in self.identical_token_index.get(key, []):
                    if (idx not in token_transformation
                            or len(token_transformation[idx]) < self.token_length[idx]
                            or not set(token_transformation[idx].values()).issubset({"IDT", })):
                        this_token_transformation = {k: "IDT" for k in range(int(self.token_length[idx]))}
                        this_token_transformation[j] = "RED"
                        token_transformation[idx] = this_token_transformation
                        original_token_length[idx] = len(part_observed_sequence.encode("utf-8")) if self.is_bytes_level else len(part_observed_sequence)
    def handle_redundant_characters(self, observed_sequence, token_transformation, original_token_length):
        # Case 2: Redundant sequences
        # The original code is too strict
        if not isinstance(observed_sequence, str):
            return
        for i in range(1, 5):
            for j in range(1, len(observed_sequence) - i):
                key = observed_sequence[i:i + j]
                if len(key) == 0:
                    continue
                removed_chars = observed_sequence[:i]
                if any([reAlNUM.match(char) for char in removed_chars]) or any([char in PUNCT for char in removed_chars]):
                    # do not remove digits, English letters, or punctuation
                    continue
                for idx in self.identical_token_index.get(key, []):
                    if idx not in token_transformation:
                        token_transformation[idx] = {k: "RED" for k in range(len(removed_chars))}
                        replaced_chars = removed_chars + key
                        original_token_length[idx] = len(replaced_chars.encode("utf-8")) if self.is_bytes_level else len(replaced_chars)
    def handle_other_similar_shape(self, i, char, token_transformation):
        r"""
        Handles characters with other similar shapes.

        Args:
            i (int): The index of the character in the observed sequence.
            char (str): The character in the observed sequence.
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
        """
        for idx in self.other_similar_shape_index.get((i, char), []):
            token_transformation[idx].setdefault(i, "OTS")
    def handle_missing_characters(self, observed_sequence, broken_token_transformation, original_token_length_for_broken):
        r"""
        Handles missing characters.

        Args:
            observed_sequence (str): The observed sequence.
            broken_token_transformation (set): A set of token indices with missing characters.
            original_token_length_for_broken (dict): A dictionary mapping token indices to their original lengths.
        """
        for i in range(len(observed_sequence)):
            part_observed_sequence = observed_sequence[:i + 1]
            for idx in self.missing_char_index.get(part_observed_sequence, []):
                broken_token_transformation.add(idx)
                original_token_length_for_broken[idx] = len(part_observed_sequence.encode("utf-8")) if self.is_bytes_level else len(part_observed_sequence)
    def filter_and_finalize_transformations(self, token_transformation, broken_token_transformation, original_token_length, original_token_length_for_broken):
        r"""
        Filters and finalizes the transformations.

        Args:
            token_transformation (dict): A dictionary mapping token indices to their corresponding transformation types.
            broken_token_transformation (set): A set of token indices with missing characters.
            original_token_length (dict): A dictionary mapping token indices to their original lengths.
            original_token_length_for_broken (dict): Original lengths for the tokens with missing characters.

        Returns:
            dict: A dictionary mapping token indices to their finalized transformation types.
        """
        new_transformation = {}
        potential_transformation = {}
        for idx, transformation in token_transformation.items():
            if transformation.get(0) == "ROR":
                new_transformation[idx] = ("ROR", ) + ("IDT",) * (self.token_length[idx] - 1)
            elif len(transformation) == self.token_length[idx] or "RED" in set(transformation.values()):
                new_transformation[idx] = tuple(transformation.values())
            elif self.token_length[idx] - len(transformation) <= 2:
                potential_transformation[idx] = transformation
        for idx in broken_token_transformation:
            if idx not in new_transformation:
                new_transformation[idx] = ("MIS",) + ("IDT",) * (self.token_length[idx] - 1)
                original_token_length[idx] = original_token_length_for_broken[idx]
        for idx, transformation in potential_transformation.items():
            if idx not in new_transformation:
                if len(transformation) >= 1:
                    new_transformation[idx] = tuple(transformation.values()) + ("UNR",) * ceil(self.token_length[idx] - len(transformation))
        return new_transformation
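    # Illustrative example (not part of the original module): for a two-character
    # vocabulary token whose first character matches the observation exactly and
    # whose second character only shares a pinyin with it, the per-position dict
    # {0: "IDT", 1: "SAP"} is finalized to the tuple ("IDT", "SAP"); a token with
    # one unexplained position would instead be padded with a trailing "UNR".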
    def handle_final_oov(self, observed_sequence):
        r"""
        Handles out-of-vocabulary (OOV) characters as a final step.

        Args:
            observed_sequence (str): The observed sequence containing OOV characters.

        Returns:
            dict: A dictionary mapping token indices to their corresponding transformation types.
        """
        new_transformation = {}
        for idx in self.identical_char_index[(0, OOV_CHAR)]:
            if self.token_length[idx] == 1:
                new_transformation[idx] = ("IDT", )  # IDT: Identical character (OOV)
        return new_transformation
    def get_transformation_type(self, observed_sequence: str):
        r"""
        Determine the transformation types from all tokens in the vocabulary to the observed sequence.

        This method analyzes the input sequence and identifies various types of character transformations
        that may have occurred, such as character substitutions, pinyin-based errors, or shape-based confusions.
        It returns a mapping of token indices to their corresponding transformation types.

        Args:
            observed_sequence (str): The input sequence of characters to be analyzed for transformations.

        Returns:
            Tuple[Dict[int, Tuple[str]], Dict[int, int]]: A tuple containing two elements:
                - A dictionary mapping token indices to a tuple of their corresponding transformation types.
                - A dictionary mapping token indices to the original (observed) lengths they cover,
                  populated for redundant- and missing-character transformations.

        Transformation Types:

        - **IDT**: Identical character (no transformation).
        - **PTC**: Prone to confusion (commonly confused characters).
        - **SAP**: Same pinyin (characters that share the same pinyin).
        - **SIP**: Similar pinyin (characters with similar pinyin).
        - **SIS**: Similar shape (characters with similar visual appearance).
        - **OTP**: Other pinyin error (pinyin-related errors not covered by SAP or SIP).
        - **OTS**: Other similar shape (shape-related errors not covered by SIS).
        - **MIS**: Missing characters (characters that are missing from the observed sequence).
        - **RED**: Redundant characters (characters that are not needed in the observed sequence).
        - **UNR**: Unrecognized transformation (no known transformation type).

        Example:
            >>> transformer = TransformationType(vocab, is_bytes_level=False)
            >>> transformations, _ = transformer.get_transformation_type("你好")
            >>> print(transformations)
            {36371: ('IDT', 'OTP'), 8225: ('IDT', 'UNR'), ...}

        Note:
            - This method relies on prior methods such as `get_pinyin_data`, `handle_identical_characters`,
              and various handlers for specific distortion types.
            - The `distortion_type_priority_order` attribute determines the order in which distortion
              handlers are applied.
        """
        # Initialize a default dictionary to hold transformations for each token index
        token_transformation = defaultdict(dict)
        # Maps token indices to the observed span length they consume (filled by the redundant/missing handlers)
        original_token_length = dict()
        # record the token indices with missing characters
        broken_token_transformation = set()
        # record the original token lengths for the tokens with missing characters
        original_token_length_for_broken = dict()
        # Initialize a variable to hold transformations for Out-Of-Vocabulary (OOV) characters
        oov_transformation = None

        # Handle Out-Of-Vocabulary characters when operating at the byte level
        if (
            self.is_bytes_level
            and len(observed_sequence) > 0
            and observed_sequence[0] not in self.char_set
        ):
            # Get OOV transformations for the observed sequence
            oov_transformation = self.handle_oov_characters(observed_sequence)

        # Retrieve pinyin data for the observed sequence
        token_pinyins, token_consonants = self.get_pinyin_data(observed_sequence)
        # Check if pinyin data retrieval was successful
        validated = token_pinyins is not None

        # Iterate over each character in the observed sequence
        for i in range(len(observed_sequence)):
            # Get the prefix up to the current character, and the current character itself
            part_observed_sequence = observed_sequence[: i + 1]
            char_i = observed_sequence[i]

            # Handle identical characters (no transformation needed)
            self.handle_identical_characters(i, char_i, token_transformation)

            # Stop scanning once the character is punctuation or whitespace
            if self.is_punctuation_or_space(char_i):
                if 'RED' in self.distortion_type_priority_order:
                    self.handle_redundant_before_punctuation_or_space(i, char_i, observed_sequence, token_transformation, original_token_length)
                self.handle_continuous_punctuation_or_space(i, observed_sequence, token_transformation)
                break

            # If pinyin data is not valid, skip further processing for this character
            if not validated:
                continue

            # Define a mapping of distortion types to their corresponding handler methods
            distortion_handlers = {
                "PTC": lambda: self.handle_prone_to_confusion(i, char_i, token_transformation),
                "SAP": lambda: self.handle_same_pinyin(i, token_pinyins, token_transformation),
                "ROR": lambda: self.handle_reorder_tokens(i, part_observed_sequence, token_transformation),
                "SIP": lambda: self.handle_similar_pinyin(i, token_pinyins, token_transformation),
                "SIS": lambda: self.handle_similar_shape(i, char_i, token_transformation),
                "IMP": lambda: self.handle_initial_pinyin_match(i, token_consonants, token_transformation),
                "OTP": lambda: self.handle_other_pinyin_error(i, token_pinyins, token_transformation),
                "RED": lambda: self.handle_redundant_character_inside_token(i, part_observed_sequence, token_transformation, original_token_length),
                "OTS": lambda: self.handle_other_similar_shape(i, char_i, token_transformation),
            }

            # Iterate over distortion types based on their priority order
            for distortion_type in self.distortion_type_priority_order:
                if distortion_type in distortion_handlers:
                    # Invoke the handler function for the current distortion type
                    distortion_handlers[distortion_type]()

        # Handle missing characters
        if 'MIS' in self.distortion_type_priority_order:
            self.handle_missing_characters(observed_sequence, broken_token_transformation, original_token_length_for_broken)

        # Handle redundant characters
        if 'RED' in self.distortion_type_priority_order:
            self.handle_redundant_characters(observed_sequence, token_transformation, original_token_length)

        # Filter the transformations to finalize the transformation types for each token
        new_transformation = self.filter_and_finalize_transformations(
            token_transformation, broken_token_transformation, original_token_length, original_token_length_for_broken
        )

        # Incorporate OOV transformations into the final transformation mapping, if any
        if oov_transformation is not None:
            new_transformation.update(oov_transformation)

        # If no transformations were found, handle the final OOV case
        if len(new_transformation) == 0:
            new_transformation = self.handle_final_oov(observed_sequence)

        # Ensure that the transformation mapping is not empty to avoid downstream errors
        assert len(new_transformation) > 0, f"No transformations found for sequence: '{observed_sequence}'"

        return new_transformation, original_token_length
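

# Illustrative usage sketch (not part of the original module). The tokenizer name
# below is a hypothetical choice for demonstration only; the class only needs a
# token -> index mapping plus the resources referenced by configs/default_config.yaml.
if __name__ == "__main__":
    from transformers import AutoTokenizer  # assumed available in the environment

    tokenizer = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")  # hypothetical model
    transformer = TransformationType(tokenizer.get_vocab(), is_bytes_level=False)
    transformations, original_lengths = transformer.get_transformation_type("你好")
    # Each entry maps a vocabulary index to a per-position tuple such as ("IDT", "SAP").
    for idx, types in list(transformations.items())[:5]:
        print(transformer.idx_to_token[idx], types)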